{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "from keras.callbacks import LambdaCallback\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense, Activation\n",
    "from keras.layers import LSTM\n",
    "from keras.optimizers import RMSprop, Adam\n",
    "from keras.utils.data_utils import get_file\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import random\n",
    "import sys\n",
    "import io\n",
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "songs = pd.read_csv('data/drake-songs.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "367372"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text = ''\n",
    "\n",
    "for index, row in songs['lyrics'].iteritems():\n",
    "    cleaned = str(row).lower().replace(' ', '\\n')\n",
    "    text = text + \" \".join(re.findall(r\"[a-z']+\", cleaned))\n",
    "    \n",
    "len(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total chars: 28\n"
     ]
    }
   ],
   "source": [
    "import re\n",
    "\n",
    "tokens = re.findall(r\"[a-z'\\s]\", text)\n",
    "\n",
    "chars = sorted(list(set(tokens)))\n",
    "print('total chars:', len(chars))\n",
    "char_indices = dict((c, i) for i, c in enumerate(chars))\n",
    "indices_char = dict((i, c) for i, c in enumerate(chars))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[' ',\n",
       " \"'\",\n",
       " 'a',\n",
       " 'b',\n",
       " 'c',\n",
       " 'd',\n",
       " 'e',\n",
       " 'f',\n",
       " 'g',\n",
       " 'h',\n",
       " 'i',\n",
       " 'j',\n",
       " 'k',\n",
       " 'l',\n",
       " 'm',\n",
       " 'n',\n",
       " 'o',\n",
       " 'p',\n",
       " 'q',\n",
       " 'r',\n",
       " 's',\n",
       " 't',\n",
       " 'u',\n",
       " 'v',\n",
       " 'w',\n",
       " 'x',\n",
       " 'y',\n",
       " 'z']"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chars"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nb sequences: 122444\n"
     ]
    }
   ],
   "source": [
    "# cut the text in semi-redundant sequences of maxlen characters\n",
    "maxlen = 40\n",
    "step = 3\n",
    "sentences = []\n",
    "next_chars = []\n",
    "\n",
    "for i in range(0, len(text) - maxlen, step):\n",
    "    sentences.append(text[i: i + maxlen])\n",
    "    next_chars.append(text[i + maxlen])\n",
    "    \n",
    "print('nb sequences:', len(sentences))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Vectorization...\n"
     ]
    }
   ],
   "source": [
    "print('Vectorization...')\n",
    "\n",
    "x = np.zeros((len(sentences), maxlen, len(chars)), dtype=np.bool)\n",
    "y = np.zeros((len(sentences), len(chars)), dtype=np.bool)\n",
    "\n",
    "for i, sentence in enumerate(sentences):\n",
    "    for t, char in enumerate(sentence):\n",
    "        x[i, t, char_indices[char]] = 1\n",
    "    y[i, char_indices[next_chars[i]]] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Build model...\n"
     ]
    }
   ],
   "source": [
    "# build the model: a single LSTM\n",
    "print('Build model...')\n",
    "model = Sequential()\n",
    "model.add(LSTM(128, input_shape=(maxlen, len(chars))))\n",
    "model.add(Dense(len(chars)))\n",
    "model.add(Activation('softmax'))\n",
    "\n",
    "model.compile(loss='categorical_crossentropy', optimizer=RMSprop(lr=0.01))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "lstm_1 (LSTM)                (None, 128)               80384     \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 28)                3612      \n",
      "_________________________________________________________________\n",
      "activation_1 (Activation)    (None, 28)                0         \n",
      "=================================================================\n",
      "Total params: 83,996\n",
      "Trainable params: 83,996\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample(preds, temperature=1.0):\n",
    "    preds = np.asarray(preds).astype('float64')\n",
    "    preds = np.log(preds) / temperature\n",
    "    exp_preds = np.exp(preds)\n",
    "    preds = exp_preds / np.sum(exp_preds)\n",
    "    probas = np.random.multinomial(1, preds, 1)\n",
    "    return np.argmax(probas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def on_epoch_end(epoch, logs):\n",
    "    print('----- Generating text after Epoch: %d' % epoch)\n",
    "\n",
    "    start_index = random.randint(0, len(text) - maxlen - 1)\n",
    "    for diversity in [0.2, 0.5]:\n",
    "        print('----- diversity:', diversity)\n",
    "\n",
    "        generated = ''\n",
    "        sentence = text[start_index: start_index + maxlen]\n",
    "        generated += sentence\n",
    "        print('----- Generating with seed: \"' + sentence + '\"')\n",
    "        sys.stdout.write(generated)\n",
    "\n",
    "        for i in range(400):\n",
    "            x_pred = np.zeros((1, maxlen, len(chars)))\n",
    "            for t, char in enumerate(sentence):\n",
    "                x_pred[0, t, char_indices[char]] = 1.\n",
    "\n",
    "            preds = model.predict(x_pred, verbose=0)[0]\n",
    "            next_index = sample(preds, diversity)\n",
    "            next_char = indices_char[next_index]\n",
    "\n",
    "            generated += next_char\n",
    "            sentence = sentence[1:] + next_char\n",
    "\n",
    "            sys.stdout.write(next_char)\n",
    "            sys.stdout.flush()\n",
    "        print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      " 52992/122444 [===========>..................] - ETA: 50s - loss: 1.2885"
     ]
    }
   ],
   "source": [
    "print_callback = LambdaCallback(on_epoch_end=on_epoch_end)\n",
    "\n",
    "history = model.fit(\n",
    "    x, \n",
    "    y,\n",
    "    batch_size=128,\n",
    "    epochs=10,\n",
    "    callbacks=[print_callback]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'history' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-16-0e6f4ac90437>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0mget_ipython\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_line_magic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'matplotlib'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'inline'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhistory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhistory\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'loss'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m: name 'history' is not defined"
     ]
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "plt.plot(history.history['loss'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_output():\n",
    "    generated = ''\n",
    "    usr_input = input(\"Write the beginning of your poem, the Drake machine will complete it. Your input is: \")\n",
    "\n",
    "    sentence = ('{0:0>' + str(Tx) + '}').format(usr_input).lower()\n",
    "    generated += usr_input \n",
    "\n",
    "    sys.stdout.write(\"\\n\\nHere is your poem: \\n\\n\") \n",
    "    sys.stdout.write(usr_input)\n",
    "    for i in range(400):\n",
    "\n",
    "        x_pred = np.zeros((1, Tx, len(chars)))\n",
    "\n",
    "        for t, char in enumerate(sentence):\n",
    "            if char != '0':\n",
    "                x_pred[0, t, char_indices[char]] = 1.\n",
    "\n",
    "        preds = model.predict(x_pred, verbose=0)[0]\n",
    "        next_index = sample(preds, temperature = 0.2)\n",
    "        next_char = indices_char[next_index]\n",
    "\n",
    "        generated += next_char\n",
    "        sentence = sentence[1:] + next_char\n",
    "\n",
    "        sys.stdout.write(next_char)\n",
    "        sys.stdout.flush()\n",
    "\n",
    "        if next_char == '\\n':\n",
    "            continue"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Write the beginning of your poem, the Drake machine will complete it. Your input is: guess it is what we make it i'm tired of\n",
      "\n",
      "\n",
      "Here is your poem: \n",
      "\n",
      "guess it is what we make it i'm tired of me i started to be the money the guing the money that shit i swear the stre the fuck it i get it all my thing that i was a grammy i got me and i get it i get it i get it i get it i get it i get it i get it i get it i get it i get it i get it i get it i get it i get it i get it i get it i get it i get it i get it i get it i get it i get it i get it i get it i get it i get it i get it i get it i ge"
     ]
    }
   ],
   "source": [
    "Tx = 40\n",
    "generate_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "utils.save_model_weights(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
